feat: Add SPSA optimization method (Issue #357)#1712
Conversation
|
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
d304c4b to
1bc8449
Compare
| def spsa_standard_schedule( | ||
| init_value: float, | ||
| decay_rate: float, | ||
| offset: float = 0.0, |
There was a problem hiding this comment.
If a user instantiates this schedule with the defaults, the very first step (count=0) will result in a ZeroDivisionError (or yield inf in JAX). Change the default offset to something mathematically stable, or at least enforce that offset > 0 if count starts at 0.
| grad_estimate = jax.tree.map( | ||
| lambda d: (y_plus - y_minus) / (2.0 * c) * d, delta | ||
| ) |
There was a problem hiding this comment.
You are recalculating for every single leaf in the PyTree. y_plus, y_minus, and c are all scalars. Calculate this scalar coefficient once outside the tree map, then just apply the multiplication
There was a problem hiding this comment.
Do this instead:
scalar_diff = (y_plus - y_minus) / (2.0 * c)
grad_estimate = jax.tree.map(lambda d: scalar_diff * d, delta)| # equivalent | ||
| # to multiplying by delta_i. We multiply for numerical stability. | ||
| grad_estimate = jax.tree.map( | ||
| lambda d: (y_plus - y_minus) / (2.0 * c) * d, delta |
There was a problem hiding this comment.
We need numerical safety here. If c decays to exactly 0 or gets sufficiently small this division will explode.
| self.assertAlmostEqual(val_0, 1.0 / (10.0**0.5)) | ||
| self.assertAlmostEqual(val_10, 1.0 / (20.0**0.5)) |
There was a problem hiding this comment.
Stick to np.testing.assert_allclose
|
Thanks for the thorough review, @servusdei2018! I've pushed a new commit addressing all of your points:
All tests pass perfectly locally. Let me know if everything looks good on your end! |
Addresses #357
Description
This PR implements the Simultaneous Perturbation Stochastic Approximation (SPSA) gradient estimator to address the open feature request #357.
Rather than implementing it as a stateful
optaxoptimizer, it is implemented as a composable gradient estimator (optax.contrib.spsa_estimator). This aligns best with JAX's functional paradigm, allowing users to pass the resultinggrad_fndirectly into any existingoptaxoptimizer (SGD, Adam, etc.) andoptax.chain. Standard polynomial schedules for learning rate and perturbation scaling are also provided.Verification
I have added rigorous unit tests in
tests/contrib/spsa_test.pyutilizingchex.all_variants:optax.sgdminimizing a noisy objective over 50 steps.jax.jitandjax.vmap.